{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e5d5c4e3",
   "metadata": {},
   "source": [
    "# Causal Inference for Time Series"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da825502",
   "metadata": {},
   "source": [
    "Causal inference involves finding the effect of intervention on one set of variables, on another variable. For instance, if A->B->C. Then all the three variables may be correlated, but intervention on C, does not affect the values of B, since C is not a causal ancestor of of B. But on the other hand, interventions on A or B, both affect the values of C. \n",
    "\n",
    "While there are many different kinds of causal inference questions one may be interested in, we currently support two kinds-- Average Treatment Effect (ATE) and conditional ATE (CATE). In ATE, we intervene on one set of variables with a treatment value and a control value, and estimate the expected change in value of some specified target variable. Mathematically,\n",
    "\n",
    "$$\\texttt{ATE} = \\mathbb{E}[Y | \\texttt{do}(X=x_t)] - \\mathbb{E}[Y | \\texttt{do}(X=x_c)]$$\n",
    "\n",
    "where $\\texttt{do}$ denotes the intervention operation. In words, ATE aims to determine the relative expected difference in the value of $Y$ when we intervene $X$ to be $x_t$ compared to when we intervene $X$ to be $x_c$. Here $x_t$ and $x_c$ are respectively the treatment value and control value.\n",
    "\n",
    "CATE makes a similar estimate, but under some condition specified for a set of variables. Mathematically,\n",
    "\n",
    "$$\\texttt{CATE} = \\mathbb{E}[Y | \\texttt{do}(X=x_t), C=c] - \\mathbb{E}[Y | \\texttt{do}(X=x_c), C=c]$$\n",
    "\n",
    "where we condition on some set of variables $C$ taking value $c$. Notice here that $X$ is intervened but $C$ is not. \n",
    "\n",
    "While ATE and CATE estimate expectation over the population, **Counterfactuals** aim at estimating the effect of an intervention on a specific instance or sample. Suppose we have a specific instance of a system of random variables $(X_1, X_2,...,X_N)$ given by $(X_1=x_1, X_2=x_2,...,X_N=x_N)$, then in a counterfactual, we want to know the effect an intervention (say) $X_1=k$ would have had on some other variable(s) (say $X_2$), holding all the remaining variables fixed. Mathematically, this can be expressed as,\n",
    "\n",
    "$$\\texttt{Counterfactual} = X_2 | \\texttt{do}(X_1=k), X_3=x_3, X_4=4,\\cdots,X_N=x_N$$\n",
    "\n",
    "To understand how causal inference works in the case of time series, let's consider the following graph as an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ec161227",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from causalai.misc.misc import plot_graph\n",
    "from causalai.data.data_generator import DataGenerator\n",
    "\n",
    "    \n",
    "fn = lambda x:x\n",
    "coef = 1.\n",
    "sem = {\n",
    "        'a': [(('a', -1), coef, fn),], \n",
    "        'b': [(('a', -1), coef, fn), (('b', -1), coef, fn),],\n",
    "        'c': [(('c', -1), coef, fn), (('b', -1), coef, fn),],\n",
    "        'd': [(('c', -1), coef, fn), (('d', -1), coef, fn),]\n",
    "        }\n",
    "T = 2000\n",
    "data,var_names,graph_gt = DataGenerator(sem, T=T, seed=0)\n",
    "plot_graph(graph_gt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "927c2082",
   "metadata": {},
   "source": [
    "Given this graph with 4 variables-- a b, c and d, and some observational data in the form for a $T \\times 4$ \n",
    "matrix, suppose we want to estimate the causal effect of interventions of\n",
    "the variable b on variable d. The SCM for this graph takes the form:\n",
    "\n",
    "$$a[t] = f_a(a[t-1]) + n_a $$\n",
    "$$b[t] = f_b(a[t-1], b[t-1]) + n_b $$\n",
    "$$c[t] = f_c(b[t-1], c[t-1]) + n_c $$\n",
    "$$d[t] = f_d(c[t-1], d[t-1]) + n_d $$\n",
    "\n",
    "Here $n_x$ are noise terms. Then intervening the values of the variable b at each time step, i.e., $do(b[t])$ for every $t$, causally affects\n",
    "the values of $d$. This is because $d$ directly depends on $c$, and $c$ depends on $b$, thus there is an indirect \n",
    "causal effect. \n",
    "\n",
    "Notice that if we were to intervene both $a$ and $b$, the intervention of $a$ \n",
    "would not have any impact on $d$ because it is blocked by $b$, which is also intervened. On the other hand, if we \n",
    "were to intervene $c$ in addition to $b$, then the intervention of $b$ would not have any impact on $d$ because it \n",
    "would be blocked by $c$.\n",
    "\n",
    "Coming back to the example shown in the above graph, we have established that an intervention on the values of $b$ \n",
    "impacts the values of $d$. Now suppose we want to calculate the treatment effect (say ATE) of this intervention on \n",
    "$d$. For the purpose of this exposition, let's consider just one of the terms in the ATE formula above, since both \n",
    "the terms have the same form. Specifically, we want to calculate,\n",
    "\n",
    "$$\\mathbb{E}_t[d[t] | \\texttt{do}(b)]$$\n",
    "Conceptually, this is achieved by setting the value of $b[t]=v$ ($v$ is any desired value) in the observational data \n",
    "at every time step \n",
    "$t \\in \\{0,1,...,T\\}$, then starting from $t=0$ in the above equations, we iterate through these equations \n",
    "in the order $b[t]$, $c[t]$, and $d[t]$ (the causal order), for each time $t$. Notice that we do not need \n",
    "to evaluate the equation for $a[t]$ because the intervention does not affect its value at any time step, and\n",
    "therefore, it remains the same as the values in the given observational data. This saves computation. We would \n",
    "similarly have ignored any other variable during this computation if it was either not affected by the intervention, \n",
    "or if there was no causal path from that variable to the target variable $d$.\n",
    "Finally, to compute $\\mathbb{E}_t[d[t] | \\texttt{do}(b)]$, we simply average over the values of $d$ computed \n",
    "using this procedure for all time steps.\n",
    "\n",
    "Notice that we do not need \n",
    "to evaluate the equation for $a$ in this process because its value has on impact on $d$ once we intervene $b$. This saves computation. We would \n",
    "similarly have ignored any other variable during this computation if it was either not affected by the intervention, \n",
    "or if there was no causal path from that variable to the target variable $d$.\n",
    "\n",
    "\n",
    "Now that we have a conceptual understanding, we point out that in reality, the functions $f_x$ for $x \\in \\{b,c,d \\}$ are unknown in practice. In fact, given only observational data, we do not even know the causal graph as the one shown in the example above. Therefore, causal inference is treated as a two step process. First we estimate the causal graph using the observational data. We then use one of the various techniques to perform causal inference given both the observational data and the causal graph.\n",
    "\n",
    "## Causal Inferencne methods supported by CausalAI\n",
    "\n",
    "In our library, for time series data, we support our in-house **causal_path method** that simulates the conceptual process described above for causal inference.\n",
    "\n",
    "#### causal_path method (defaut)\n",
    "\n",
    "Conceptually, this method works in two steps. For illustration, let's use the causal graph shown above as our example.\n",
    "1. We train two models $P_{\\theta_1}(c[t]|c[t-1], b[t-1])$ and $P_{\\theta_2}(d[t]|d[t-1], c[t-1])$ to predict $c[t]$ from $c[t-1], b[t-1]$, and $d[t]$ from $d[t-1], c[t-1]$, using the observational data. We have not used the intervention information in this step.\n",
    "2. we set the value of $b[t]=v$ ($v$ is the desired intervention value) for all the time steps in the observational data, then traverse the causal graph\n",
    "in the order $b$, $c$, and $d$ (the causal order), for each observation. For each of the nodes c and d, we use the corresponding trained models $P_{\\theta_1}(c[t]|c[t-1], b[t-1])$ and $P_{\\theta_2}(d[t]|d[t-1], c[t-1])$ as proxies for the unknown functions $f_c$ and $f_d$, and follow the steps described above to estimate the causal effect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fbfdf106",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline\n",
    "import pickle as pkl\n",
    "import time\n",
    "from functools import partial\n",
    "\n",
    "from causalai.misc.misc import plot_graph\n",
    "from causalai.data.data_generator import DataGenerator, ConditionalDataGenerator\n",
    "from causalai.models.time_series.causal_inference import CausalInference\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.neural_network import MLPRegressor\n",
    "\n",
    "def define_treatments(name, t,c):\n",
    "    treatment = dict(var_name=name,\n",
    "                    treatment_value=t,\n",
    "                    control_value=c)\n",
    "    return treatment"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61cfda62",
   "metadata": {},
   "source": [
    "## Continuous Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eca19870",
   "metadata": {},
   "source": [
    "### Average Treatment Effect (ATE)\n",
    "For this example, we will use synthetic data that has linear dependence among data variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "147c9301",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': [],\n",
       " 'b': [('a', -1), ('f', -1)],\n",
       " 'c': [('b', -2), ('f', -2)],\n",
       " 'd': [('b', -4), ('g', -1)],\n",
       " 'e': [('f', -1)],\n",
       " 'f': [],\n",
       " 'g': []}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fn = lambda x:x\n",
    "coef = 0.1\n",
    "sem = {\n",
    "        'a': [], \n",
    "        'b': [(('a', -1), coef, fn), (('f', -1), coef, fn)], \n",
    "        'c': [(('b', -2), coef, fn), (('f', -2), coef, fn)],\n",
    "        'd': [(('b', -4), coef, fn), (('g', -1), coef, fn)],\n",
    "        'e': [(('f', -1), coef, fn)], \n",
    "        'f': [],\n",
    "        'g': [],\n",
    "        }\n",
    "T = 5000\n",
    "data,var_names,graph_gt = DataGenerator(sem, T=T, seed=0)\n",
    "graph_gt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f3b161b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True ATE = 1.20\n"
     ]
    }
   ],
   "source": [
    "# Notice c does not depend on a if we intervene on b. Hence intervening a has no effect in this case. \n",
    "# This can be verified by changing the intervention values of variable a, which should have no impact on the ATE. \n",
    "# (see graph_gt above)\n",
    "\n",
    "t1='a' \n",
    "t2='b'\n",
    "target = 'c'\n",
    "target_var = var_names.index(target)\n",
    "\n",
    "intervention11 = 1*np.ones(T)\n",
    "intervention21 = 10*np.ones(T)\n",
    "intervention_data1,_,_ = DataGenerator(sem, T=T, seed=0,\n",
    "                        intervention={t1:intervention11, t2:intervention21})\n",
    "\n",
    "intervention12 = -0.*np.ones(T)\n",
    "intervention22 = -2.*np.ones(T)\n",
    "intervention_data2,_,_ = DataGenerator(sem, T=T, seed=0,\n",
    "                        intervention={t1:intervention12, t2:intervention22})\n",
    "\n",
    "\n",
    "\n",
    "true_effect = (intervention_data1[:,target_var] - intervention_data2[:,target_var]).mean()\n",
    "print(\"True ATE = %.2f\" %true_effect)"
   ]
  },
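  {
   "cell_type": "markdown",
   "id": "b7e41a90",
   "metadata": {},
   "source": [
    "As a sanity check on the value printed above, the true ATE can also be worked out by hand from the structural equations in `sem`. This is a sketch that assumes the noise is additive and identical in both intervened datasets (both use `seed=0`), so that the noise and the unintervened parent `f` cancel in the difference. The symbol $n_c$ for the noise of `c` is our notation. Since `c` depends on `b` only through the lagged term with coefficient $0.1$,\n",
    "\n",
    "$$c[t] = 0.1\\, b[t-2] + 0.1\\, f[t-2] + n_c$$\n",
    "\n",
    "$$\\texttt{ATE} = \\mathbb{E}[c[t] | \\texttt{do}(b=10)] - \\mathbb{E}[c[t] | \\texttt{do}(b=-2)] = 0.1 \\times (10 - (-2)) = 1.2$$\n",
    "\n",
    "The intervention values of `a` do not appear in this expression, which is consistent with the comment in the code above that intervening on `a` has no effect once `b` is intervened."
   ]
  },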
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1dc3f8b1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Estimated ATE: 1.19\n",
      "0.98s\n"
     ]
    }
   ],
   "source": [
    "\n",
    "tic = time.time()\n",
    "\n",
    "\n",
    "treatments = [define_treatments(t1, intervention11,intervention12),\\\n",
    "             define_treatments(t2, intervention21,intervention22)]\n",
    "# CausalInference_ = CausalInference(data, var_names, graph_gt,\\\n",
    "#         partial(MLPRegressor, hidden_layer_sizes=(100,100)) , False)\n",
    "CausalInference_ = CausalInference(data, var_names, graph_gt, LinearRegression , discrete=False)\n",
    "\n",
    "ate, y_treat,y_control = CausalInference_.ate(target, treatments)\n",
    "print(f'Estimated ATE: {ate:.2f}')\n",
    "toc = time.time()\n",
    "print(f'{toc-tic:.2f}s')\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d24b9b45",
   "metadata": {},
   "source": [
    "### Conditional Average Treatement Effect (CATE)\n",
    "\n",
    "The data is generated using the following structural equation model:\n",
    "$$C = noise$$\n",
    "$$W = C + noise$$\n",
    "$$X = C*W + noise$$\n",
    "$$Y = C*X + noise$$\n",
    "\n",
    "We will treat C as the condition variable, X as the intervention variable, and Y as the target variable in our example below. The noise used in our example is sampled from the standard Gaussian distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "18ed77e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "T=5000\n",
    "data, var_names, graph_gt = ConditionalDataGenerator(T=T, data_type='time_series', seed=0, discrete=False)\n",
    "# var_names = ['C', 'W', 'X', 'Y']\n",
    "treatment_var='X'\n",
    "target = 'Y'\n",
    "target_idx = var_names.index(target)\n",
    "\n",
    "intervention1 = 0.1*np.ones(T, dtype=float)\n",
    "intervention_data1,_,_ = ConditionalDataGenerator(T=T, data_type='time_series',\\\n",
    "                                    seed=0, intervention={treatment_var:intervention1}, discrete=False)\n",
    "\n",
    "intervention2 = 0.9*np.ones(T, dtype=float)\n",
    "intervention_data2,_,_ = ConditionalDataGenerator(T=T, data_type='time_series',\\\n",
    "                                    seed=0, intervention={treatment_var:intervention2}, discrete=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "07824782",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Approx True CATE: -1.69\n",
      "Estimated CATE: -1.81\n",
      "Time taken: 5.17s\n"
     ]
    }
   ],
   "source": [
    "condition_state=2.1\n",
    "diff = np.abs(data[:,0] - condition_state)\n",
    "idx = np.argmin(diff)\n",
    "# assert diff[idx]<0.1, f'No observational data exists for the conditional variable close to {condition_state}'\n",
    "\n",
    "\n",
    "cate_gt = (intervention_data1[idx,target_idx] - intervention_data2[idx,target_idx])\n",
    "print(f'Approx True CATE: {cate_gt:.2f}')\n",
    "\n",
    "####\n",
    "treatments = define_treatments(treatment_var, intervention1,intervention2)\n",
    "conditions = {'var_name': 'C', 'condition_value': condition_state}\n",
    "\n",
    "tic = time.time()\n",
    "model = partial(MLPRegressor, hidden_layer_sizes=(100,100), max_iter=200)\n",
    "CausalInference_ = CausalInference(data, var_names, graph_gt, model, discrete=False)#\n",
    "\n",
    "cate = CausalInference_.cate(target, treatments, conditions, model)\n",
    "toc = time.time()\n",
    "print(f'Estimated CATE: {cate:.2f}')\n",
    "print(f'Time taken: {toc-tic:.2f}s')"
   ]
  },
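  {
   "cell_type": "markdown",
   "id": "c3f9d2e7",
   "metadata": {},
   "source": [
    "As a rough check on the ground-truth value above, the CATE can be worked out by hand from the structural equation model stated earlier. This is a sketch that assumes $Y = C*X + noise$ holds with $C$ and $X$ entering at the same time step, and that the noise is identical in the two intervened datasets (both are generated with `seed=0`). Here $x_t=0.1$ and $x_c=0.9$ are the treatment and control values passed to `define_treatments`, and $c=2.1$ is `condition_state`:\n",
    "\n",
    "$$\\texttt{CATE} = \\mathbb{E}[Y | \\texttt{do}(X=x_t), C=c] - \\mathbb{E}[Y | \\texttt{do}(X=x_c), C=c] = c\\,(x_t - x_c) = 2.1 \\times (0.1 - 0.9) = -1.68$$\n",
    "\n",
    "The printed value differs slightly, presumably because it is evaluated at the observed sample whose value of $C$ is closest to $2.1$, not at exactly $C=2.1$."
   ]
  },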
  {
   "cell_type": "markdown",
   "id": "cc4a6d23",
   "metadata": {},
   "source": [
    "### Counterfactual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "04b799d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "fn = lambda x:x\n",
    "coef = 0.1\n",
    "sem = {\n",
    "        'a': [], \n",
    "        'b': [(('a', -1), coef, fn), (('f', -1), coef, fn)], \n",
    "        'c': [(('b', 0), coef, fn), (('f', -2), coef, fn)],\n",
    "        'd': [(('b', -4), coef, fn), (('g', -1), coef, fn)],\n",
    "        'e': [(('f', -1), coef, fn)], \n",
    "        'f': [],\n",
    "        'g': [],\n",
    "        }\n",
    "T = 5000\n",
    "data,var_names,graph_gt = DataGenerator(sem, T=T, seed=0)\n",
    "# plot_graph(graph_gt, node_size=500)\n",
    "\n",
    "intervention={'b':np.array([10.]*10), 'e':np.array([-100.]*10)}\n",
    "target_var = 'c'\n",
    "\n",
    "sample, _, _= DataGenerator(sem, T=10, noise_fn=None,\\\n",
    "                                    intervention=None, discrete=False, nstates=10, seed=0)\n",
    "sample_intervened, _, _= DataGenerator(sem, T=10, noise_fn=None,\\\n",
    "                                    intervention=intervention, discrete=False, nstates=10, seed=0)\n",
    "\n",
    "sample=sample[-1] # use the last time step as our sample\n",
    "sample_intervened=sample_intervened[-1] # use the last time step as our sample and compute ground truth intervention\n",
    "var_orig = sample[var_names.index(target_var)]\n",
    "var_counterfactual_gt = sample_intervened[var_names.index(target_var)] # ground truth counterfactual\n",
    "# print(f'Original value of var {target_var}: {var_orig:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "59ce3061",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True counterfactual 1.16\n",
      "Estimated counterfactual 1.26\n"
     ]
    }
   ],
   "source": [
    "interventions = {name:float(val[0]) for name, val in intervention.items()}\n",
    "print(f'True counterfactual {var_counterfactual_gt:.2f}')\n",
    "\n",
    "# model = partial(MLPRegressor, hidden_layer_sizes=(100,100), max_iter=200)\n",
    "model = LinearRegression\n",
    "# model=None\n",
    "CausalInference_ = CausalInference(data, var_names, graph_gt, model, discrete=False)\n",
    "# model = None\n",
    "counterfactual_et = CausalInference_.counterfactual(sample, target_var, interventions, model)\n",
    "print(f'Estimated counterfactual {counterfactual_et:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61bf79d9",
   "metadata": {},
   "source": [
    "## Discrete Data\n",
    "\n",
    "The synthetic data generation procedure for the ATE and CATE examples below are identical to the procedure followed above for the continuous case, except that the generated data is discrete in the cases below.\n",
    "\n",
    "\n",
    "**Importantly**, when referring as discrete, we only treat the intervention variables as discrete in this case. The target variables and other variables are considered as continuous. Specifically, it doesn't make sense for the target variable to be discrete when we compute ATE or CATE, because it involves estimating the difference in states of the target variable, and for discrete variables, the difference between two states is not a meaningful quantity (as discrete states are symbolic in nature)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac9c1c47",
   "metadata": {},
   "source": [
    "### Average Treatment Effect (ATE)\n",
    " For this example, we will use synthetic data that has linear dependence among data variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6fd05b93",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "import pickle as pkl\n",
    "import time\n",
    "from functools import partial\n",
    "\n",
    "from causalai.data.data_generator import DataGenerator, ConditionalDataGenerator\n",
    "from causalai.models.time_series.causal_inference import CausalInference\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.neural_network import MLPRegressor\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "def define_treatments(name, t,c):\n",
    "    treatment = dict(var_name=name,\n",
    "                    treatment_value=t,\n",
    "                    control_value=c)\n",
    "    return treatment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c8945c77",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': [],\n",
       " 'b': [('a', -1), ('f', -1)],\n",
       " 'c': [('b', -2), ('f', -2)],\n",
       " 'd': [('b', -4), ('b', -1), ('g', -1)],\n",
       " 'e': [('f', -1)],\n",
       " 'f': [],\n",
       " 'g': []}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fn = lambda x:x\n",
    "coef = 0.5\n",
    "sem = {\n",
    "        'a': [], \n",
    "        'b': [(('a', -1), coef, fn), (('f', -1), coef, fn)], \n",
    "        'c': [(('b', -2), coef, fn), (('f', -2), coef, fn)],\n",
    "        'd': [(('b', -4), coef, fn), (('b', -1), coef, fn), (('g', -1), coef, fn)],\n",
    "        'e': [(('f', -1), coef, fn)], \n",
    "        'f': [],\n",
    "        'g': [],\n",
    "        }\n",
    "T = 5000\n",
    "\n",
    "t1='a'\n",
    "t2='b'\n",
    "target = 'c'\n",
    "discrete = {name:True if name in [t1,t2] else False for name in sem.keys()}\n",
    "\n",
    "data,var_names,graph_gt = DataGenerator(sem, T=T, seed=0, discrete=discrete, nstates=10)\n",
    "graph_gt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10893fac",
   "metadata": {},
   "source": [
    "Notice how we specify the variable discrete above. We specify the intervention variables as discrete, while the others as continuous, as per our explanation above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7d5855e7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ground truth ATE = 0.82\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "target_var = var_names.index(target)\n",
    "\n",
    "# note that states can be [0,1,...,9], so the multiples below must be in this range\n",
    "intervention11 = 0*np.ones(T, dtype=int)\n",
    "intervention21 = 7*np.ones(T, dtype=int)\n",
    "intervention_data1,_,_ = DataGenerator(sem, T=T, seed=0,\n",
    "                            intervention={t1: intervention11, t2:intervention21}, discrete=discrete, nstates=10)\n",
    "\n",
    "intervention12 = 9*np.ones(T, dtype=int)\n",
    "intervention22 = 2*np.ones(T, dtype=int)\n",
    "intervention_data2,_,_ = DataGenerator(sem, T=T, seed=0,\n",
    "                            intervention={t1:intervention12, t2:intervention22}, discrete=discrete, nstates=10)\n",
    "\n",
    "true_effect = (intervention_data1[:,target_var] - intervention_data2[:,target_var]).mean()\n",
    "print(\"Ground truth ATE = %.2f\" %true_effect)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8f5d1d21",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Estimated ATE: 0.78\n",
      "Time taken: 2.08s\n"
     ]
    }
   ],
   "source": [
    "\n",
    "tic = time.time()\n",
    "\n",
    "treatments = [define_treatments(t1, intervention11,intervention12),\\\n",
    "             define_treatments(t2, intervention21,intervention22)]\n",
    "model = partial(MLPRegressor, hidden_layer_sizes=(100,100), max_iter=200) # LinearRegression\n",
    "CausalInference_ = CausalInference(data, var_names, graph_gt, model, discrete=True)#\n",
    "o, y_treat,y_control = CausalInference_.ate(target, treatments)\n",
    "print(f'Estimated ATE: {o:.2f}')\n",
    "toc = time.time()\n",
    "print(f'Time taken: {toc-tic:.2f}s')\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95e1cad3",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "### CATE (conditional ATE)\n",
    "For this example we will use synthetic data that has non-linear dependence among data variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a1b9c6c2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'C': [],\n",
       " 'W': [('C', 0)],\n",
       " 'X': [('C', 0), ('W', 0)],\n",
       " 'Y': [('C', 0), ('X', 0)]}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "T=5000\n",
    "treatment_var='X'\n",
    "target = 'Y'\n",
    "target_idx = ['C', 'W', 'X', 'Y'].index(target)\n",
    "\n",
    "discrete = {name:True if name==treatment_var else False for name in ['C', 'W', 'X', 'Y']}\n",
    "data, var_names, graph_gt = ConditionalDataGenerator(T=T, data_type='time_series', seed=0, discrete=discrete, nstates=10)\n",
    "# var_names = ['C', 'W', 'X', 'Y']\n",
    "\n",
    "\n",
    "\n",
    "# note that states can be [0,1,...,9], so the multiples below must be in this range\n",
    "intervention1 = 9*np.ones(T, dtype=int)\n",
    "intervention_data1,_,_ = ConditionalDataGenerator(T=T, data_type='time_series',\\\n",
    "                                    seed=0, intervention={treatment_var:intervention1}, discrete=discrete, nstates=10)\n",
    "\n",
    "intervention2 = 1*np.ones(T, dtype=int)\n",
    "intervention_data2,_,_ = ConditionalDataGenerator(T=T, data_type='tabular',\\\n",
    "                                    seed=0, intervention={treatment_var:intervention2}, discrete=discrete, nstates=10)\n",
    "graph_gt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2e6c3738",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Approx True CATE: 4.61\n",
      "Estimated CATE: 1.55\n",
      "Time taken: 7.01s\n"
     ]
    }
   ],
   "source": [
    "condition_var = 'C'\n",
    "condition_var_idx = var_names.index(condition_var)\n",
    "condition_state=0.5\n",
    "idx = np.argmin(np.abs(data[:,condition_var_idx]-condition_state))\n",
    "cate_gt = (intervention_data1[idx,target_idx] - intervention_data2[idx,target_idx]).mean()\n",
    "print(f'Approx True CATE: {cate_gt:.2f}')\n",
    "\n",
    "####\n",
    "treatments = define_treatments(treatment_var, intervention1,intervention2)\n",
    "conditions = {'var_name': condition_var, 'condition_value': condition_state}\n",
    "\n",
    "tic = time.time()\n",
    "model = partial(MLPRegressor, hidden_layer_sizes=(100,100), max_iter=200)\n",
    "CausalInference_ = CausalInference(data, var_names, graph_gt, model, discrete=True)#\n",
    "\n",
    "cate = CausalInference_.cate(target, treatments, conditions, model)\n",
    "toc = time.time()\n",
    "print(f'Estimated CATE: {cate:.2f}')\n",
    "print(f'Time taken: {toc-tic:.2f}s')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0f30847",
   "metadata": {},
   "source": [
    "### Counterfactual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "11306bfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "fn = lambda x:x\n",
    "coef = 0.1\n",
    "sem = {\n",
    "        'a': [], \n",
    "        'b': [(('a', -1), coef, fn), (('f', -1), coef, fn)], \n",
    "        'c': [(('b', 0), coef, fn), (('f', -2), coef, fn)],\n",
    "        'd': [(('b', -4), coef, fn), (('g', -1), coef, fn)],\n",
    "        'e': [(('f', -1), coef, fn)], \n",
    "        'f': [],\n",
    "        'g': [],\n",
    "        }\n",
    "T = 5000\n",
    "\n",
    "intervention={'b':np.array([9]*10), 'e':np.array([0]*10)}\n",
    "target_var = 'c'\n",
    "discrete = {name:True if name in intervention.keys() else False for name in sem.keys()}\n",
    "\n",
    "data,var_names,graph_gt = DataGenerator(sem, T=T, seed=0, discrete=discrete)\n",
    "# plot_graph(graph_gt, node_size=500)\n",
    "\n",
    "sample, _, _= DataGenerator(sem, T=10, noise_fn=None,\\\n",
    "                                    intervention=None, discrete=discrete, nstates=10, seed=0)\n",
    "sample_intervened, _, _= DataGenerator(sem, T=10, noise_fn=None,\\\n",
    "                                    intervention=intervention, discrete=discrete, nstates=10, seed=0)\n",
    "\n",
    "sample=sample[-1] # use the last time step as our sample\n",
    "sample_intervened=sample_intervened[-1] # use the last time step as our sample and compute ground truth intervention\n",
    "var_orig = sample[var_names.index(target_var)]\n",
    "var_counterfactual_gt = sample_intervened[var_names.index(target_var)] # ground truth counterfactual\n",
    "# print(f'Original value of var {target_var}: {var_orig:.2f}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8e9dee4f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True counterfactual 0.30\n",
      "Estimated counterfactual 0.22\n"
     ]
    }
   ],
   "source": [
    "interventions = {name:val[0] for name, val in intervention.items()}\n",
    "print(f'True counterfactual {var_counterfactual_gt:.2f}')\n",
    "# model = partial(MLPRegressor, hidden_layer_sizes=(100,100), max_iter=200)\n",
    "model = LinearRegression\n",
    "# model=None\n",
    "CausalInference_ = CausalInference(data, var_names, graph_gt, model, discrete=True)\n",
    "# model = None\n",
    "counterfactual_et = CausalInference_.counterfactual(sample, target_var, interventions, model)\n",
    "print(f'Estimated counterfactual {counterfactual_et:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c93ce27",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
